{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 数据增强\n",
    "    image_gen_train = tf.keras.preprocessing.image.ImageDataGenerator(\n",
    "        rescale = 所有数据将乘以该数值\n",
    "        rotation_range = 随机旋转角度数范围\n",
    "        width_shift_range = 随机宽度偏移量\n",
    "        height_shift_range = 随机高度偏移量\n",
    "        水平翻转：horizontal_flip = 是否随机水平翻转\n",
    "        随机缩放：zoom_range = 随机缩放的范围 [1-n，1+n] )\n",
    "    \n",
    "    image_gen_train.fit(x_train)   \n",
    "\n",
    "例：\n",
    "\n",
    "    image_gen_train = ImageDataGenerator(\n",
    "        rescale=1. / 1., # 如为图像，分母为255时，可归至0～1\n",
    "        rotation_range=45, # 随机45度旋转\n",
    "        width_shift_range=.15, # 宽度偏移\n",
    "        height_shift_range=.15, # 高度偏移\n",
    "        horizontal_flip=False, # 水平翻转\n",
    "        zoom_range=0.5 # 将图像随机缩放阈量50％)\n",
    "    image_gen_train.fit(x_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "\n",
    "fashion = tf.keras.datasets.fashion_mnist\n",
    "(x_train, y_train), (x_test, y_test) = fashion.load_data()\n",
    "x_train, x_test = x_train / 255.0, x_test / 255.0\n",
    "\n",
    "x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)  # 给数据增加一个维度，使数据和网络结构匹配\n",
    "\n",
    "image_gen_train = ImageDataGenerator(\n",
    "    rescale=1. / 1.,  # 如为图像，分母为255时，可归至0～1\n",
    "    rotation_range=45,  # 随机45度旋转\n",
    "    width_shift_range=.15,  # 宽度偏移\n",
    "    height_shift_range=.15,  # 高度偏移\n",
    "    horizontal_flip=True,  # 水平翻转\n",
    "    zoom_range=0.5  # 将图像随机缩放阈量50％\n",
    ")\n",
    "image_gen_train.fit(x_train)\n",
    "\n",
    "model = tf.keras.models.Sequential([\n",
    "    tf.keras.layers.Flatten(),\n",
    "    tf.keras.layers.Dense(128, activation='relu'),\n",
    "    tf.keras.layers.Dense(10, activation='softmax')\n",
    "])\n",
    "\n",
    "model.compile(optimizer='adam',\n",
    "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),\n",
    "              metrics=['sparse_categorical_accuracy'])\n",
    "\n",
    "model.fit(image_gen_train.flow(x_train, y_train, batch_size=32), epochs=5, validation_data=(x_test, y_test),\n",
    "          validation_freq=1)\n",
    "model.summary()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
